{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Image Classification with Apache Spark\n",
    "\n",
    "In this example, we will use Jupyter Notebook to run image Classification with Apache Spark on Scala. To execute this Scala kernel successfully, you need to install [Almond](https://almond.sh/), a Scala kernel for Jupyter Notebook. Almond provide extensive functionalities for Scala and Spark applications.\n",
    "\n",
    "[Almond installation instruction](https://almond.sh/docs/quick-start-install) (Note: only Scala 2.12 are tested)\n",
    "\n",
    "After that, you can start with DJL's Scala notebook.\n",
    "\n",
    "\n",
    "## Import dependencies\n",
    "\n",
    "Firstly, let's import the depdendencies we need. We choose to use DJL PyTorch as our backend engine. You can also switch to MXNet by uncommenting the two lines for MXNet and comment PyTorch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import $ivy.`org.apache.spark::spark-sql:3.0.1`\n",
    "import $ivy.`org.apache.spark::spark-mllib:3.0.1`\n",
    "import $ivy.`ai.djl:api:0.10.0`\n",
    "import $ivy.`ai.djl.pytorch:pytorch-model-zoo:0.10.0`\n",
    "import $ivy.`ai.djl.pytorch:pytorch-native-auto:1.7.1`\n",
    "// import $ivy.`ai.djl.mxnet:mxnet-model-zoo:0.10.0`\n",
    "// import $ivy.`ai.djl.mxnet:mxnet-native-auto:1.7.0-backport`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we can import the packages we need to use. In the last two lines, we disabled the Spark logging in order to avoid polluting your cell outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.util\n",
    "import ai.djl.Model\n",
    "import ai.djl.modality.Classifications\n",
    "import ai.djl.modality.cv.transform.{ Resize, ToTensor}\n",
    "import ai.djl.ndarray.types.{DataType, Shape}\n",
    "import ai.djl.ndarray.{NDList, NDManager}\n",
    "import ai.djl.repository.zoo.{Criteria, ModelZoo, ZooModel}\n",
    "import ai.djl.training.util.ProgressBar\n",
    "import ai.djl.translate.{Batchifier, Pipeline, Translator, TranslatorContext}\n",
    "import ai.djl.util.Utils\n",
    "import org.apache.spark.ml.image.ImageSchema\n",
    "import org.apache.spark.sql.functions.col\n",
    "import org.apache.spark.sql.{Encoders, Row, NotebookSparkSession}\n",
    "import org.apache.log4j.{Level, Logger}\n",
    "Logger.getLogger(\"org\").setLevel(Level.OFF) // avoid too much message popping out\n",
    "Logger.getLogger(\"ai\").setLevel(Level.OFF) // avoid too much message popping out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Translator\n",
    "\n",
    "A Translator in DJL is used to define the preprocessing and postprocessing logic. The following code is to \n",
    "\n",
    "- preprocess: convert a Spark DataFrame Row to DJL NDArray.\n",
    "- postprocess: convert inference result to classifications"
   ]
  },
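  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The preprocessing step depends on the layout of Spark's image data: a flat byte array in height, width, channel (HWC) order, with the channels stored as BGR. As a rough sketch of what reversing the channel axis means, the plain-Scala snippet below converts a made-up 1x2 image from BGR to RGB. The `bgrToRgb` helper and the sample bytes are hypothetical and are not part of Spark or DJL; the translator itself does this with `image.flip(2)` on the `NDArray`.\n",
    "\n",
    "```scala\n",
    "// Hypothetical helper: reverse the channel order of every pixel in a flat HWC array.\n",
    "def bgrToRgb(data: Array[Byte], channels: Int): Array[Byte] =\n",
    "  data.grouped(channels).flatMap(_.reverse).toArray\n",
    "\n",
    "// A made-up 1x2 image with 3 channels: two pixels, each stored as (B, G, R).\n",
    "val bgr: Array[Byte] = Array(1, 2, 3, 4, 5, 6)\n",
    "val rgb = bgrToRgb(bgr, 3)\n",
    "println(rgb.mkString(\", \")) // prints 3, 2, 1, 6, 5, 4\n",
    "```\n",
    "\n",
    "Only the order within each pixel changes; the pixel positions stay the same."
   ]
  },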
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "  // Translator: a class used to do preprocessing and post processing\n",
    "  class MyTranslator extends Translator[Row, Classifications] {\n",
    "\n",
    "    private var classes: java.util.List[String] = new util.ArrayList[String]()\n",
    "    private val pipeline: Pipeline = new Pipeline()\n",
    "      .add(new Resize(224, 224))\n",
    "      .add(new ToTensor())\n",
    "\n",
    "    override def prepare(manager: NDManager, model: Model): Unit = {\n",
    "        classes = Utils.readLines(model.getArtifact(\"synset.txt\").openStream())\n",
    "      }\n",
    "\n",
    "    override def processInput(ctx: TranslatorContext, row: Row): NDList = {\n",
    "\n",
    "      val height = ImageSchema.getHeight(row)\n",
    "      val width = ImageSchema.getWidth(row)\n",
    "      val channel = ImageSchema.getNChannels(row)\n",
    "      var image = ctx.getNDManager.create(ImageSchema.getData(row), new Shape(height, width, channel)).toType(DataType.UINT8, true)\n",
    "      // BGR to RGB\n",
    "      image = image.flip(2)\n",
    "      pipeline.transform(new NDList(image))\n",
    "    }\n",
    "\n",
    "    // Deal with the output.，NDList contains output result, usually one or more NDArray(s).\n",
    "    override def processOutput(ctx: TranslatorContext, list: NDList): Classifications = {\n",
    "      var probabilitiesNd = list.singletonOrThrow\n",
    "      probabilitiesNd = probabilitiesNd.softmax(0)\n",
    "      new Classifications(classes, probabilitiesNd)\n",
    "    }\n",
    "\n",
    "    override def getBatchifier: Batchifier = Batchifier.STACK\n",
    "  }"
   ]
  },
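  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `softmax(0)` call in `processOutput` turns the raw model outputs (logits) into probabilities that sum to 1. As an illustration of that arithmetic only, here is a minimal plain-Scala sketch. The `softmax` helper and the sample logits are hypothetical and are not part of DJL, which computes this natively on the `NDArray`.\n",
    "\n",
    "```scala\n",
    "// Hypothetical helper for illustration; DJL computes softmax on the NDArray itself.\n",
    "def softmax(logits: Array[Double]): Array[Double] = {\n",
    "  // Subtract the maximum first so that exp() cannot overflow.\n",
    "  val max = logits.max\n",
    "  val exps = logits.map(x => math.exp(x - max))\n",
    "  val sum = exps.sum\n",
    "  exps.map(_ / sum)\n",
    "}\n",
    "\n",
    "val probs = softmax(Array(2.0, 1.0, 0.1)) // made-up logits\n",
    "println(probs.mkString(\", \"))\n",
    "```\n",
    "\n",
    "The largest logit receives the largest probability, so the ranking of the classes is unchanged."
   ]
  },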
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the model\n",
    "\n",
    "Now, we just need to fetch the model from a URL. The url can be a hdfs (hdfs://), file (file://) or http (https://) format. We use Criteria as a container to store the model and translator information. Then, all we need to do is to load the model from it.\n",
    "\n",
    "Note: DJL Criteria and Model are not serializable, so we add `lazy` declaration.\n",
    "\n",
    "If you are using MXNet as the backend engine, plase uncomment the mxnet model url."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "val modelUrl = \"https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip?model_name=traced_resnet18\"\n",
    "// val modelUrl = \"https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/mxnet_resnet18.zip?model_name=resnet18_v1\"\n",
    "lazy val criteria = Criteria.builder\n",
    "  .setTypes(classOf[Row], classOf[Classifications])\n",
    "  .optModelUrls(modelUrl)\n",
    "  .optTranslator(new MyTranslator())\n",
    "  .optProgress(new ProgressBar)\n",
    "  .build()\n",
    "lazy val model = ModelZoo.loadModel(criteria)"
   ]
  },
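  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `lazy` keyword used above defers construction until the value is first accessed. The following minimal sketch shows that behavior in plain Scala. `Expensive` and `initCount` are hypothetical stand-ins for a costly resource and are not DJL classes. How this interacts with Spark closure serialization depends on your setup, so treat it as an illustration of `lazy val` semantics only.\n",
    "\n",
    "```scala\n",
    "// Hypothetical stand-in for a costly resource such as a loaded model.\n",
    "var initCount = 0\n",
    "class Expensive {\n",
    "  initCount += 1 // counts how many times the constructor runs\n",
    "  def describe: String = \"ready\"\n",
    "}\n",
    "\n",
    "lazy val resource = new Expensive // nothing is constructed yet\n",
    "println(initCount)                // the constructor has not run at this point\n",
    "println(resource.describe)        // first access triggers construction\n",
    "println(resource.describe)        // later accesses reuse the same instance\n",
    "println(initCount)\n",
    "```\n",
    "\n",
    "The constructor runs once, at the first access, no matter how many times `resource` is used afterwards."
   ]
  },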
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start Spark application\n",
    "\n",
    "We can create a `NotebookSparkSession` through the Almond Spark plugin. It will internally apply all necessary jars to each of the worker node."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Create Spark session\n",
    "val spark = {\n",
    "  NotebookSparkSession.builder()\n",
    "    .master(\"local[*]\")\n",
    "    .getOrCreate()\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's try to load the images from the local folder using Spark ML library:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "val df = spark.read.format(\"image\").option(\"dropInvalid\", true).load(\"../image-classification/images\")\n",
    "df.select(\"image.origin\", \"image.width\", \"image.height\").show(truncate=false)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then We can run inference on these images. All we need to do is to create a `Predictor` and run inference with DJL."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "val result = df.select(col(\"image.*\")).mapPartitions(partition => {\n",
    "  val predictor = model.newPredictor()\n",
    "  partition.map(row => {\n",
    "    // image data stored as HWC format\n",
    "    predictor.predict(row).toString\n",
    "  })\n",
    "})(Encoders.STRING)\n",
    "println(result.collect().mkString(\"\\n\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Scala",
   "language": "scala",
   "name": "scala"
  },
  "language_info": {
   "codemirror_mode": "text/x-scala",
   "file_extension": ".sc",
   "mimetype": "text/x-scala",
   "name": "scala",
   "nbconvert_exporter": "script",
   "version": "2.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
